# isort: off
# fmt: off
from dataclasses import dataclass

import triton
from triton_kernels import target_info
from triton_kernels.target_info import get_cdna_version
from triton_kernels.tensor import FP4
import torch
from triton_kernels.tensor_details.layout_details.hopper_scale import HopperMXScaleLayout
from .opt_flags_details import opt_flags_amd, opt_flags_nvidia
from triton_kernels.tensor import bitwidth, get_layout

@dataclass
class OptFlags:
    block_m: int
    block_n: int
    block_k: int
    num_warps: int
    num_stages: int
    group_m: int
    xcd_swizzle: int
    w_cache_modifier: str
    split_k: int
    is_persistent: bool
    idle_sms: int
    epilogue_subtile: int | None
    arch: str
    target_kernel_kwargs: dict


def max_allowable_mn(
    max_mn: int,
    m: int,
    n: int,
    split_k: int,
    ) -> int:
    return 1 if m * n >= max_mn else split_k


def all_constraints_satisfied(opt_flags: OptFlags, constraints: dict) -> bool:
    _split_k_constraints = ['split_k', 'max_allowable_mn']
    assert all(getattr(opt_flags, ck) == cv for ck, cv in constraints.items() if cv is not None and ck not in _split_k_constraints)
    if constraints.get('split_k') and not constraints.get('max_allowable_mn'):
        assert opt_flags.split_k == constraints['split_k']


def make_default_opt_flags_amd(
    out_dtype,
    lhs_dtype,
    rhs_dtype,
    precision_config,
    batch_size,
    m,
    n,
    k,
    ragged_metadata,
    can_use_persistent_tma,
    can_use_split_k,
    enforce_bitwise_invariance,
    epilogue_effective_itemsize,
    x_transpose,
    has_y_acc_in,
    constraints,
):
    constraints_supported = ["block_m", "block_n", "block_k", "split_k", "is_persistent", "epilogue_subtile", "max_allowable_mn"]
    assert not any([c not in constraints_supported for c in constraints]), constraints.keys()
    # tokens per slice
    if ragged_metadata is None:
        slice_size = m
    elif ragged_metadata.expected_slice_size is None:
        slice_size = max(1, m // ragged_metadata.n_slices)
    else:
        slice_size = ragged_metadata.expected_slice_size

    is_cdna4 = get_cdna_version() == 4
    # block_m
    if constraints.get("block_m", None):
        block_m = constraints["block_m"]
    elif enforce_bitwise_invariance:
        block_m = 256 if is_cdna4 else 128
    elif slice_size >= 512 and n >= 2048:
        block_m = 256 if is_cdna4 else 128
    elif is_cdna4 and m >= 512:
        block_m = 128
    else:
        block_m = max(32, min(triton.next_power_of_2(slice_size), 64))

    if ragged_metadata is not None:
        grid_m = ragged_metadata.n_blocks(ragged_metadata.n_slices, m, block_m)
    else:
        grid_m = triton.cdiv(m, block_m)
    # group_m:
    group_m = 4
    # number of xcds
    num_xcds = 8
    xcd_swizzle = num_xcds
    # block_nk:
    # TODO: Does opt_flags_amd.compute_block_nk need to be refactored?
    block_n, block_k = opt_flags_amd.compute_block_nk(
        n, block_m, grid_m, num_xcds, lhs_dtype, rhs_dtype, precision_config
    )
    is_persistent = constraints.get("is_persistent", False)
    # split_k:
    split_k = 1
    if constraints.get("max_allowable_mn", 0) > 0 and constraints.get("split_k") is not None:
        split_k = max_allowable_mn(constraints["max_allowable_mn"], m, n, constraints.get("split_k"))
    elif constraints.get("split_k", None) is not None:
        split_k = constraints["split_k"]
    elif can_use_split_k and not enforce_bitwise_invariance:
        grid_size = grid_m * ((n + block_n - 1) // block_n)
        n_cu = torch.cuda.get_device_properties(0).multi_processor_count
        split_k = max(1, n_cu // grid_size)
    # w_cache_modifier:
    w_cache_modifier = ".cg" if block_m <= 32 else None
    # num_warps, num_stages
    num_warps = 2 if (m is not None and m <= 16) else 8
    num_stages = 2
    # AMD-specific
    target_kernel_kwargs = {"waves_per_eu": 0, "matrix_instr_nonkdim": 16, "kpack": 1}
    epilogue_subtile = constraints.get('epilogue_subtile', None)
    if epilogue_subtile is None:
        epilogue_subtile = 1

    # prevents OutOfSharedMemoryError for mxfp8 on CDNA3
    if get_cdna_version() == 3 and bitwidth(rhs_dtype) == 8 and precision_config.b_mx_scale is not None:
        num_stages = 1

    # specific configs for F16 x MXFP4 on CDNA4
    # Note that these configs will exceed LDS usage with async copy enabled
    if is_cdna4 and bitwidth(lhs_dtype) == 16 and bitwidth(rhs_dtype) == 4 and precision_config.b_mx_scale is not None:
        split_k = 1
        if m <= 1024:
            target_kernel_kwargs["waves_per_eu"] = 3
            block_n = 128
            block_k = 256
            num_warps = 4
        else:
            target_kernel_kwargs["waves_per_eu"] = 0
            block_m = 64
            block_n = 512
            block_k = 256
            num_warps = 8

    def replace_with_valid_constraint(k: str, v):
        if constraints.get(k, None) is not None:
            return constraints[k]
        else:
            return v

    ret = OptFlags(
        block_m=replace_with_valid_constraint('block_m', block_m),
        block_n=replace_with_valid_constraint('block_n', block_n),
        block_k=replace_with_valid_constraint('block_k', block_k),
        num_warps=num_warps,
        num_stages=num_stages,
        group_m=group_m,
        xcd_swizzle=xcd_swizzle,
        w_cache_modifier=w_cache_modifier,
        split_k=split_k,
        is_persistent=is_persistent,
        idle_sms=0,
        epilogue_subtile=epilogue_subtile,
        arch=None,
        target_kernel_kwargs=target_kernel_kwargs,
    )
    # check constraints
    all_constraints_satisfied(ret, constraints)
    return ret

def make_default_opt_flags_nvidia(
    out_dtype,
    lhs_dtype,
    rhs_dtype,
    precision_config,
    batch_size,
    m,
    n,
    k,
    routing_data,
    can_use_persistent_tma,
    can_use_split_k,
    enforce_bitwise_invariance,
    epilogue_effective_itemsize,
    x_transpose,
    has_y_acc_in,
    constraints,
):
    constraints_supported = ["block_m", "block_k", "split_k", "is_persistent", "epilogue_subtile", "num_stages", "idle_sms", "max_allowable_mn"]
    assert not any([c not in constraints_supported for c in constraints]), constraints.keys()
    # tokens per expert
    if routing_data is None or batch_size > 1:
        slice_size = m
    elif routing_data.expected_slice_size is None:
        slice_size = max(1, m // routing_data.n_slices)
    else:
        slice_size = routing_data.expected_slice_size
    # pid swizzling
    group_m = 8
    xcd_swizzle = 1
    # block_m
    if constraints.get("block_m", None):
        block_m = constraints["block_m"]
    elif enforce_bitwise_invariance:
        block_m = 128
    else:
        if slice_size <= 64 and routing_data is not None and routing_data.slice_sizes is not None:
            # Ragged and likely memory bound; set the block size higher to minimize loading weights more than once.
            if (
                lhs_dtype == torch.bfloat16
                and rhs_dtype == FP4
                and slice_size >= 16
                and torch.cuda.get_device_capability()[0] >= 10
            ):
                block_m = max(16, min(triton.next_power_of_2(8 * slice_size), 128))
            else:
                block_m = max(16, min(triton.next_power_of_2(2 * slice_size), 64))
            if block_m == 64 and precision_config.c_mx_scale is not None and rhs_dtype == FP4 and torch.cuda.get_device_capability()[0] >= 10:
                # when having both fused_activation and mxfp8 downcast in epilogue, block_m=64 causing shared memory overflow
                block_m = 128
        else:
            block_m = max(16, min(triton.next_power_of_2(slice_size), 128))
    # block n
    arch = None
    block_n, block_n_tma = opt_flags_nvidia.compute_block_n(n, arch, precision_config)
    # is_persistent
    grid_size_tma = opt_flags_nvidia.compute_grid_size(routing_data, batch_size, m, n, block_m, block_n_tma)
    n_sms = torch.cuda.get_device_properties(0).multi_processor_count
    tiles_per_sm = grid_size_tma / n_sms
    supports_persistent = can_use_persistent_tma and (arch is None or int(arch[2:-1]) >= 9)
    requires_persistent = (get_layout(precision_config.a_mx_scale) is not None or get_layout(precision_config.b_mx_scale) is not None) and target_info.has_native_mxfp()
    if constraints.get("is_persistent", None) is not None:
        is_persistent = constraints["is_persistent"]
    elif requires_persistent:
        assert supports_persistent, "persistent kernel required but not supported"
        is_persistent = True
    else:
        has_simple_epilogue = precision_config.max_num_imprecise_acc is None
        is_persistent = supports_persistent and has_simple_epilogue and (tiles_per_sm >= 2.0 or lhs_dtype.itemsize <= 1) and out_dtype.itemsize < 4
        # TMA is slower for batched matmuls with small m/n/k.
        if m * n * k < 131072:
            is_persistent = False
        if (
            (b_scale_layout := get_layout(precision_config.b_mx_scale)) is not None and
            isinstance(b_scale_layout, HopperMXScaleLayout)
        ):
            # TODO: persistent kernel is currently slower than non-persistent
            is_persistent = False
    # adjust block_n based on is_persistent signal
    block_n = block_n_tma if is_persistent else block_n
    # adjust block_m based on is_persistent signal
    if is_persistent and opt_flags_nvidia.is_x_scale_swizzled(precision_config):
        # a mx scale has been swizzled to BlackwellActMXScaleLayout, enforce block_m=128 to align with swizzling layout
        block_m = 128
    # block k
    block_k = opt_flags_nvidia.compute_block_k(m, k, is_persistent, lhs_dtype, rhs_dtype, precision_config, has_y_acc_in)
    if block_n == 256 and block_k == 128 and block_m <= 64 and is_persistent and rhs_dtype == FP4 and k >= 4096 and slice_size > 1 and lhs_dtype != torch.bfloat16:
        # Swap block_n and block_k for mxfp4 weights so that block_k is a full cacheline, so long as K is sufficiently large.
        # TODO: swizzle the HBM layout of the weights instead
        block_n, block_k = block_k, block_n
    if constraints.get("block_k", None) is not None:
        block_k = constraints["block_k"]
    # split_k
    split_k = 1
    if constraints.get("max_allowable_mn", 0) > 0 and constraints.get("split_k") is not None:
        split_k = max_allowable_mn(constraints["max_allowable_mn"], m, n, constraints.get("split_k"))
    elif constraints.get("split_k", None) is not None:
        split_k = constraints["split_k"]
    elif can_use_split_k and not enforce_bitwise_invariance:
        estimated_actual_grid_size = opt_flags_nvidia.compute_grid_size(None, batch_size, m, n, block_m, block_n)
        split_k = opt_flags_nvidia.compute_split_k(block_k, k, estimated_actual_grid_size)
    compute_num_stages_args = (
        precision_config,
        is_persistent,
        block_m,
        block_n,
        block_k,
        torch.float32 if split_k > 1 else out_dtype,
        lhs_dtype,
        rhs_dtype,
        x_transpose,
        epilogue_effective_itemsize,
        has_y_acc_in,
    )

    if constraints.get("epilogue_subtile", None) is not None:
        subtiles_to_check = [constraints["epilogue_subtile"]]
    else:
        subtiles_to_check = [1, 2, 4]
    num_stages = -1
    for ep in subtiles_to_check:
        ns = opt_flags_nvidia.compute_num_stages(*compute_num_stages_args, epilogue_subtile=ep)
        if ns > num_stages:
            epilogue_subtile, num_stages = ep, ns

    if constraints.get("num_stages", None):
        num_stages = constraints["num_stages"]
    assert num_stages >= 1
    # Handshake with the HBM swizzling
    num_warps = opt_flags_nvidia.compute_num_warps(block_m, block_n, is_persistent, precision_config)
    ret = OptFlags(
        block_m=block_m,
        block_n=block_n,
        block_k=block_k,
        num_warps=num_warps,
        num_stages=num_stages,
        group_m=group_m,
        xcd_swizzle=xcd_swizzle,
        w_cache_modifier=None,
        split_k=split_k,
        is_persistent=is_persistent,
        epilogue_subtile=epilogue_subtile,
        arch=arch,
        target_kernel_kwargs=dict(),
        idle_sms=constraints.get("idle_sms", 0),
    )
    # check constraints
    all_constraints_satisfied(ret, constraints)
    return ret

# --------------
# User Interface
# --------------

_opt_flags_constraints: dict = dict()
_opt_flags: OptFlags | None = None

def update_opt_flags_constraints(constraints: dict[str, int]):
    global _opt_flags_constraints
    _opt_flags_constraints.update(constraints)

def reset_opt_flags_constraints():
    global _opt_flags_constraints
    _opt_flags_constraints = dict()

def reset_opt_flags():
    global _opt_flags
    _opt_flags = None

def set_opt_flags(opt_flags: OptFlags):
    global _opt_flags
    assert not _opt_flags_constraints, "setting constraints is incompatible with manual flags override"
    assert not _opt_flags, "opt_flags already set; please reset to None first"
    _opt_flags = opt_flags

class InapplicableConstraint(Exception):
    pass

def make_opt_flags(
    out_dtype,
    lhs_dtype,
    rhs_dtype,
    precision_config,
    batch_size,
    m,
    n,
    k,
    ragged_metadata,
    can_use_persistent_tma,
    can_use_split_k,
    epilogue_effective_itemsize,
    x_transpose,
    has_y_acc_in,
    block_k,
):
    if _opt_flags_constraints.get("is_persistent", False) and not can_use_persistent_tma:
        raise InapplicableConstraint("cannot enforce `is_persistent=True` constraint")
    if _opt_flags_constraints.get("split_k") is not None and _opt_flags_constraints.get("split_k") > 1 and not can_use_split_k:
        raise InapplicableConstraint("cannot enforce `split_k=True` constraint")
    if _opt_flags_constraints.get("max_allowable_mn"):
        if not _opt_flags_constraints.get("split_k"):
            raise InapplicableConstraint("split_k also needs to be provided with max_allowable_mn")
    enforce_bitwise_invariance = precision_config.enforce_bitwise_invariance
    if _opt_flags is not None:
        assert not _opt_flags_constraints
        assert block_k is None
        return _opt_flags
    opt_flags_constraints = _opt_flags_constraints
    if block_k is not None:
        opt_flags_constraints = opt_flags_constraints.copy()
        opt_flags_constraints.update(block_k=block_k, split_k=1)
    args = [out_dtype, lhs_dtype, rhs_dtype, precision_config, batch_size, m, n, k,
            ragged_metadata, can_use_persistent_tma, can_use_split_k,
            enforce_bitwise_invariance, epilogue_effective_itemsize, x_transpose, has_y_acc_in,
            opt_flags_constraints]
    backend = triton.runtime.driver.active.get_current_target().backend
    if backend == "hip":
        return make_default_opt_flags_amd(*args)
    if backend == "cuda":
        return make_default_opt_flags_nvidia(*args)
    assert False
